# !/usr/bin/env python3
# -*- coding:utf-8 -*-

# @Time    : 2024/5/22 16:01
# @Author  : weizjajj 
# @Email   : weizhongjie.wzj@antgroup.com
# @FileName: wenxin_langchain_instance.py


from typing import List, Optional, Any, Dict, Iterator, AsyncIterator

from langchain_community.chat_models import QianfanChatEndpoint
from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk
from langchain_core.outputs import ChatResult, ChatGeneration, ChatGenerationChunk

from agentuniverse.llm.llm import LLM


class WenXinLangChainInstance(QianfanChatEndpoint):
    """LangChain chat model that sends Qianfan (WenXin) requests through an agentUniverse ``LLM``.

    Configuration (credentials, model name, limits) is copied from the wrapped
    ``llm`` at construction time. Every generation and stream call is delegated
    to ``llm.call`` / ``llm.acall``, and the ``.raw`` payload they return is
    converted into LangChain message / result objects.
    """

    # The agentUniverse LLM that actually performs the requests.
    llm: LLM = None

    def __init__(self, llm: LLM):
        """Build the LangChain endpoint from an agentUniverse ``LLM``.

        Args:
            llm: Source of credentials, model name and request settings. It is
                also kept on the instance and used for every call.
        """
        # Falsy settings (None, 0, False) fall back to the defaults below.
        # NOTE(review): this also replaces an explicit max_retries=0 or
        # temperature=0 with the default -- confirm that this is intended.
        init_params = {
            "qianfan_ak": llm.api_key,
            "qianfan_sk": llm.secret_key,
            "model": llm.model_name,
            "max_tokens": llm.max_tokens,
            "timeout": llm.request_timeout,
            "max_retries": llm.max_retries if llm.max_retries else 2,
            "streaming": llm.streaming if llm.streaming else False,
            "temperature": llm.temperature if llm.temperature else 0.7,
            "llm": llm,
        }
        super().__init__(**init_params)

    # ------------------------------------------------------------------ helpers

    def _build_llm_call_params(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]],
            **kwargs: Any,
    ) -> Dict[str, Any]:
        """Convert LangChain messages into keyword arguments for ``llm.call``/``llm.acall``."""
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        return params

    @staticmethod
    def _raw_payload(res: Any) -> Optional[Dict[str, Any]]:
        """Return the ``.raw`` payload of a streamed LLM output, or None if there is none.

        Streams may yield empty items. Such items are skipped rather than
        dereferenced, which would fail on ``None.raw`` or an empty payload.
        """
        if not res:
            return None
        return res.raw or None

    @staticmethod
    def _payload_to_chunk(payload: Dict[str, Any]) -> ChatGenerationChunk:
        """Turn one streamed Qianfan payload into a ``ChatGenerationChunk``."""
        msg = _convert_dict_to_message(payload)
        return ChatGenerationChunk(
            text=payload["result"],
            message=AIMessageChunk(
                content=msg.content,
                role="assistant",
                # Only the function-call portion goes on the chunk message.
                # The full converted kwargs are exposed as generation_info.
                additional_kwargs=msg.additional_kwargs.get("function_call", {}),
            ),
            generation_info=msg.additional_kwargs,
        )

    def _payload_to_chat_result(self, payload: Dict[str, Any]) -> ChatResult:
        """Turn a complete (non-streamed) Qianfan payload into a ``ChatResult``."""
        gen = ChatGeneration(
            message=_convert_dict_to_message(payload),
            generation_info={
                "finish_reason": "stop",
                **payload.get("body", {}),
            },
        )
        llm_output = {"token_usage": payload.get("usage", {}), "model_name": self.model}
        return ChatResult(generations=[gen], llm_output=llm_output)

    def _streamed_chat_result(self, completion: str, generation_info: Dict) -> ChatResult:
        """Wrap text accumulated from a stream into a single-generation ``ChatResult``.

        Args:
            completion: Concatenated text of all streamed chunks.
            generation_info: ``generation_info`` of the last chunk that had one.
                Its ``"usage"`` entry, if any, is reported as token usage.
        """
        gen = ChatGeneration(
            message=AIMessage(content=completion, additional_kwargs={}),
            generation_info=dict(finish_reason="stop"),
        )
        return ChatResult(
            generations=[gen],
            llm_output={
                "token_usage": generation_info.get("usage", {}),
                "model_name": self.model,
            },
        )

    # ------------------------------------------------------------- generation

    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion through the wrapped agentUniverse LLM.

        When ``self.streaming`` is set, the result is assembled from ``_stream``.
        Otherwise a single ``llm.call`` is made.

        Args:
            messages: The messages to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager that receives streamed tokens.

        Returns:
            A ``ChatResult`` holding one generation plus token usage and model name.
        """
        if self.streaming:
            pieces: List[str] = []
            last_info: Dict = {}
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                if chunk.generation_info is not None:
                    last_info = chunk.generation_info
                pieces.append(chunk.text)
            return self._streamed_chat_result("".join(pieces), last_info)

        params = self._build_llm_call_params(messages, stop, **kwargs)
        return self._payload_to_chat_result(self.llm.call(**params).raw)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate`` that uses ``_astream`` / ``llm.acall``."""
        if self.streaming:
            pieces: List[str] = []
            last_info: Dict = {}
            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                if chunk.generation_info is not None:
                    last_info = chunk.generation_info
                pieces.append(chunk.text)
            return self._streamed_chat_result("".join(pieces), last_info)

        params = self._build_llm_call_params(messages, stop, **kwargs)
        output = await self.llm.acall(**params)
        return self._payload_to_chat_result(output.raw)

    # -------------------------------------------------------------- streaming

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream a chat completion through the wrapped agentUniverse LLM.

        Args:
            messages: The messages to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager that is notified of each token.

        Yields:
            One ``ChatGenerationChunk`` per non-empty item returned by ``llm.call``.
        """
        params = self._build_llm_call_params(messages, stop, **kwargs)
        params["stream"] = True
        for res in self.llm.call(**params):
            payload = self._raw_payload(res)
            if payload is None:
                continue
            chunk = self._payload_to_chunk(payload)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    async def _astream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``.

        With ``stream=True``, ``llm.acall`` is awaited to obtain an async iterable.
        """
        params = self._build_llm_call_params(messages, stop, **kwargs)
        params["stream"] = True
        async for res in await self.llm.acall(**params):
            # Check the item before touching .raw. A None item previously
            # raised AttributeError here.
            payload = self._raw_payload(res)
            if payload is None:
                continue
            chunk = self._payload_to_chunk(payload)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
